from __future__ import print_function

import os
import math
import torch.nn as nn
import torch
import numpy as np
from .utils.utils import get_activation_function, make_conv_block, make_fc_block


# Layer layouts keyed by depth name; ``VGG`` looks its layout up here.
# An integer entry is the output channel count of one conv block.  String
# entries are interpreted by ``VGG.make_conv_layers``:
#   'M'  -> nn.MaxPool2d(kernel_size=2, stride=2)
#   'AM' -> AbsMaxPool(kernel_size=2, stride=2)
#   'A'  -> nn.AvgPool2d(kernel_size=2, stride=2)
#   'S'  -> adds no layer itself; a regular (non channel-scaled) conv block
#           that directly follows it is built with stride 2
default_cfg = {
    '5': [64, 64],
    '7': [64, 64, 'M', 128, 128],
    '11': [
        64, 'M',
        128, 'M',
        256, 256, 'M',
        512, 512, 'M',
        512, 512,
    ],
    '13': [
        64, 64, 'M',
        128, 128, 'M',
        256, 256, 'M',
        512, 512, 'M',
        512, 512,
    ],
    '16': [
        64, 64, 'M',
        128, 128, 'M',
        256, 256, 256, 'M',
        512, 512, 512, 'M',
        512, 512, 512,
    ],
    '16AM': [
        64, 64, 'AM',
        128, 128, 'AM',
        256, 256, 256, 'AM',
        512, 512, 512, 'AM',
        512, 512, 512,
    ],
    '16A': [
        64, 64, 'A',
        128, 128, 'A',
        256, 256, 256, 'A',
        512, 512, 512, 'A',
        512, 512, 512,
    ],
    '16S': [
        64, 64, 'S',
        128, 128, 'S',
        256, 256, 256, 'S',
        512, 512, 512, 'S',
        512, 512, 512,
    ],
    '16_long_channel_remain': [
        64, 64, 'M',
        128, 128, 'M',
        256, 256, 256,
        512, 512, 512,
        512, 512, 512,
        'M', 'M',
    ],
    '16_long_channel_reduce': [
        64, 64, 'M',
        128, 128, 'M',
        256, 256, 256,
        256, 256, 256,
        256, 256, 256,
        'M', 'M',
    ],
    '19': [
        64, 64, 'M',
        128, 128, 'M',
        256, 256, 256, 256, 'M',
        512, 512, 512, 512, 'M',
        512, 512, 512, 512,
    ],
}

# Constant-width variants of the 16-entry layout, keyed '16:<width>': every
# conv block uses the same channel count and the pooling positions match '16'.
default_cfg.update({
    '16:%d' % width: [
        width, width, 'M',
        width, width, 'M',
        width, width, width, 'M',
        width, width, width, 'M',
        width, width, width,
    ]
    for width in (32, 64, 128, 256, 512)
})


class VGG(nn.Module):
    """VGG-style network assembled from a ``default_cfg`` layer layout.

    The model consists of ``self.feature`` (conv blocks and pooling layers),
    ``self.avgpool`` (a dataset-specific pooling layer) and
    ``self.classifier`` (fully connected blocks).  Besides ``forward`` the
    class offers helpers that replay the forward pass module by module and
    collect intermediate tensors or their min/max statistics.

    NOTE(review): the tracing helpers treat ``nn.Sequential`` children of
    ``feature`` / ``classifier`` as "blocks" and everything else as a plain
    layer; this presumes ``make_conv_block`` / ``make_fc_block`` return
    ``nn.Sequential`` containers of leaf modules -- confirm against ``utils``.
    """

    def __init__(self, activation_type, num_classes=10, depth=16, s_channel_ratio=1.0,
                 oper_order='cba', dataset='cifar', xavier_init=False, cut_block=None, bn_momentum=0.1,
                 designated_cut_index=None, designated_cut_acti_index=None, shallow_fc=False):
        """Build the feature extractor and classifier.

        Args:
            activation_type: Passed to ``get_activation_function`` to obtain
                the activation generator handed to every block builder.
            num_classes: Output size of the classifier.
            depth: Key into ``default_cfg``; converted with ``str()`` so both
                ``16`` and ``'16'`` select the same layout.
            s_channel_ratio: When not 1, the channel count of conv blocks whose
                1-based layout position is in ``saturated_block_indices`` is
                multiplied by this ratio.
            oper_order: String of one-letter operation codes; forwarded to the
                block builders as a list of characters.
            dataset: Selects the classifier head (names containing ``cifar``,
                ``tinyImageNet``, ``ImageNet``, ``cub200`` or ``flower102``).
            xavier_init: Forwarded to ``_initialize_weights``, which currently
                ignores it.
            cut_block: If given, number of trailing conv entries to drop from
                the layout.  Only ``'M'`` pooling entries are accounted for.
            bn_momentum: Forwarded to the block builders.
            designated_cut_index: Ascending layout positions to delete.
            designated_cut_acti_index: Layout positions whose conv block is
                built with the ``'a'`` operation removed from ``oper_order``.
            shallow_fc: cifar head only -- use a single ``nn.Linear(512, ...)``.

        Raises:
            KeyError: If ``str(depth)`` is not a key of ``default_cfg``.
            ValueError: If ``dataset`` matches no known classifier head.
        """
        super(VGG, self).__init__()
        self.designated_cut_acti_index = designated_cut_acti_index
        self.shallow_fc = shallow_fc

        # default_cfg keys are strings; normalising lets the int default work.
        depth = str(depth)
        # Work on a copy: the edits below must never alter the shared table.
        cfg = list(default_cfg[depth])

        if designated_cut_index is not None:
            print("indices of list should be ascending")
            # Delete from the back so earlier positions stay valid; reversed()
            # leaves the caller's list untouched.
            for ind in reversed(designated_cut_index):
                del cfg[ind]

        if cut_block is not None:
            print("Only for Max pooling! "*10)
            cut = int(cut_block)
            conv_block_cnt = 0
            pool_cnt = 0

            # Walk backwards until `cut` conv entries have been seen, counting
            # the max-pool entries passed on the way.
            for c in reversed(cfg):
                if c != 'M':
                    conv_block_cnt += 1
                else:
                    pool_cnt += 1

                if conv_block_cnt == cut:
                    break

            # Drop that tail, then re-append the removed pools so the spatial
            # down-sampling factor is unchanged.
            cfg = cfg[:-(cut + pool_cnt)] + ['M'] * pool_cnt

        self.saturated_block_indices = [4, 5, 7, 8, 9, 11]
        self.s_channel_ratio = s_channel_ratio
        self.bn_momentum = bn_momentum

        self.activation_generator = get_activation_function(activation_type)
        self.dataset = dataset
        # 'full' is every operation code; 'front2' drops the last one.
        self.oper_order = {'full': list(oper_order), 'front2': list(oper_order)[:-1]}

        self.cfg = cfg
        if 'ImageNet' == dataset or 'cub200' in dataset or 'flower102' in dataset:
            # These datasets get one extra pooling entry of the layout's kind.
            if depth == '16AM':
                self.cfg = self.cfg + ['AM']
            elif depth == '16A':
                self.cfg = self.cfg + ['A']
            elif depth == '16':
                self.cfg = self.cfg + ['M']

        self.feature = self.make_conv_layers()
        self.avgpool, self.classifier = self.make_fc_layers(num_classes)

        self._initialize_weights(xavier_init)

    def make_conv_layers(self):
        """Translate ``self.cfg`` into the ``nn.Sequential`` feature extractor.

        Returns:
            nn.Sequential: One entry per non-``'S'`` layout item, in order.
        """
        layers = []
        in_channels = 3
        for block_ind, v in enumerate(self.cfg):
            if v == 'M':
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            elif v == 'AM':
                # NOTE(review): AbsMaxPool is neither defined nor imported in
                # the visible part of this file -- confirm where it comes from.
                layers.append(AbsMaxPool(kernel_size=2, stride=2))
            elif v == 'A':
                layers.append(nn.AvgPool2d(kernel_size=2, stride=2))
            elif v == 'S':
                # No layer of its own; the following conv block checks for it.
                continue
            elif block_ind + 1 in self.saturated_block_indices and self.s_channel_ratio != 1:
                print("Channel change in saturated block")
                saturated_block_channels = int(v * self.s_channel_ratio)
                # This branch always uses stride 1 and the full operation
                # list, regardless of a preceding 'S' or
                # designated_cut_acti_index.
                layers.append(make_conv_block(in_channels, saturated_block_channels,
                                              self.activation_generator,
                                              kernel_size=3, stride=1, padding=1,
                                              oper_order=self.oper_order['full'],
                                              bn_momentum=self.bn_momentum))
                in_channels = saturated_block_channels
            else:
                # A conv block right after 'S' uses stride 2.  Padding stays 1
                # for both strides.
                follows_stride_marker = block_ind > 0 and self.cfg[block_ind - 1] == 'S'
                stride = 2 if follows_stride_marker else 1

                if (self.designated_cut_acti_index is not None
                        and block_ind in self.designated_cut_acti_index):
                    # Copy before removing so the shared list stays intact.
                    block_operation_list = list(self.oper_order['full'])
                    block_operation_list.remove('a')
                else:
                    block_operation_list = self.oper_order['full']

                layers.append(make_conv_block(in_channels, v, self.activation_generator,
                                              kernel_size=3, stride=stride, padding=1,
                                              oper_order=block_operation_list,
                                              bn_momentum=self.bn_momentum))
                in_channels = v

        return nn.Sequential(*layers)

    def make_fc_layers(self, out_classes):
        """Create the dataset-specific pooling layer and classifier head.

        Args:
            out_classes: Number of outputs of the final fully connected layer.

        Returns:
            tuple: ``(avgpool, classifier)`` where ``classifier`` is an
            ``nn.Sequential``.

        Raises:
            ValueError: If ``self.dataset`` matches none of the known heads.
        """
        layers = []
        # Last numeric layout entry = channel count leaving the feature stack.
        # Every string token (including 'A') is a non-conv marker.
        last_channel = [c for c in self.cfg if not isinstance(c, str)][-1]
        # NOTE(review): these overrides key on channel width alone, so the
        # '16:64' / '16:128' layouts also take them -- confirm that is intended.
        if last_channel == 64: # VGG-5(actually 4)
            last_channel = 16384
        elif last_channel == 128: #VGG-7(actually 6)
            last_channel = 8192

        if 'cifar' in self.dataset:
            avgpool = nn.AvgPool2d(2)

            if self.shallow_fc:
                layers = [nn.Linear(512, out_classes)]
            else:
                layers = [
                    make_fc_block(activations_in=last_channel, activations_out=512,
                                  activation_generator=self.activation_generator,
                                  oper_order=self.oper_order['full'],
                                  bn_momentum=self.bn_momentum),
                    make_fc_block(activations_in=512, activations_out=out_classes,
                                  activation_generator=self.activation_generator,
                                  oper_order='f',
                                  bn_momentum=self.bn_momentum),
                ]

        # Must be tested before the plain 'ImageNet' substring check below.
        elif 'tinyImageNet' in self.dataset:
            avgpool = nn.MaxPool2d(kernel_size=2, stride=2)
            dim = last_channel * 2 * 2

            layers = [
                make_fc_block(activations_in=dim, activations_out=dim//2,
                              activation_generator=self.activation_generator,
                              oper_order="faD", bn_momentum=self.bn_momentum),
                make_fc_block(activations_in=dim//2, activations_out=dim//2,
                              activation_generator=self.activation_generator,
                              oper_order="faD", bn_momentum=self.bn_momentum),
                make_fc_block(activations_in=dim//2, activations_out=out_classes,
                              activation_generator=self.activation_generator,
                              oper_order="f", bn_momentum=self.bn_momentum),
            ]

        elif 'ImageNet' in self.dataset or 'cub200' in self.dataset or 'flower102' in self.dataset:
            avgpool = nn.AdaptiveAvgPool2d((7, 7))

            layers = [
                make_fc_block(activations_in=last_channel * 7 * 7, activations_out=4096,
                              activation_generator=self.activation_generator,
                              oper_order="faD", bn_momentum=self.bn_momentum),
                make_fc_block(activations_in=4096, activations_out=4096,
                              activation_generator=self.activation_generator,
                              oper_order="faD", bn_momentum=self.bn_momentum),
                make_fc_block(activations_in=4096, activations_out=out_classes,
                              activation_generator=self.activation_generator,
                              oper_order="f", bn_momentum=self.bn_momentum),
            ]

        else:
            # Fail with a clear message instead of an unbound local below.
            raise ValueError("Unsupported dataset for VGG classifier head: {!r}".format(self.dataset))

        return avgpool, nn.Sequential(*layers)

    def forward(self, x):
        """Run features, optional pooling, flatten and classifier on ``x``."""
        x = self.feature(x)

        if self.avgpool is not None:
            x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        y = self.classifier(x)

        return y

    @staticmethod
    def _channel_minmax(x):
        """Return a ``(C, 2)`` tensor of per-channel ``[min, max]`` (dim 1 = channel)."""
        cbhw = x.transpose(0, 1)
        flatten_channel = cbhw.reshape(cbhw.size(0), -1)

        c_max, _ = torch.max(flatten_channel, dim=1)
        c_min, _ = torch.min(flatten_channel, dim=1)

        return torch.stack([c_min, c_max], dim=1)

    @staticmethod
    def _layer_minmax(x):
        """Return a 2-element tensor ``[min, max]`` over all of ``x``."""
        return torch.stack([x.min(), x.max()], dim=0)

    def _trace(self, x, targets=None, summarize=None):
        """Replay the forward pass and collect selected intermediate outputs.

        Args:
            x: Network input.
            targets: ``None`` records the output of every ``nn.Sequential``
                child of ``feature`` / ``classifier`` (block mode).  Otherwise
                a pair ``(feature_types, classifier_types)``; sub-modules are
                visited via ``modules()`` and outputs of instances of the
                given types are recorded.
            summarize: Optional callable applied to each recorded tensor.

        Returns:
            list: Recorded (optionally summarised) tensors in execution order.
        """
        recorded = []

        def _record(tensor):
            recorded.append(tensor if summarize is None else summarize(tensor))

        for stage_idx, stage in enumerate((self.feature, self.classifier)):
            if stage_idx == 1:
                # Same transition as forward(): pool, then flatten.
                if self.avgpool is not None:
                    x = self.avgpool(x)
                x = x.view(x.size(0), -1)

            if targets is None:
                for module in stage:
                    x = module(x)
                    if isinstance(module, nn.Sequential):
                        _record(x)
            else:
                wanted = targets[stage_idx]
                for module in stage.modules():
                    if isinstance(module, wanted):
                        x = module(x)
                        _record(x)
                    elif not isinstance(module, nn.Sequential):
                        # Containers are skipped; modules() yields their
                        # children separately.
                        x = module(x)

        return recorded

    def get_minmax(self, x, block_output=False, channel_flag=False):
        """Collect min/max statistics of intermediate outputs.

        Args:
            x: Network input.
            block_output: If true, use the output of every ``nn.Sequential``
                block; otherwise the output of every activation module (type
                taken from the next item of ``self.activation_generator``).
            channel_flag: If true, record per-channel ``(C, 2)`` statistics,
                otherwise one ``[min, max]`` pair per recorded tensor.

        Returns:
            list: One statistics tensor per recorded output.
        """
        summarize = self._channel_minmax if channel_flag else self._layer_minmax

        if block_output:
            return self._trace(x, summarize=summarize)

        # get activation function's output
        acti_cls = type(next(self.activation_generator))
        return self._trace(x, targets=(acti_cls, acti_cls), summarize=summarize)

    def get_activation(self, x, target='activation'): #block_output=False):
        """Collect intermediate output tensors.

        Args:
            x: Network input.
            target: ``'weights'`` records outputs of ``nn.Conv2d`` (feature
                part) and ``nn.Linear`` (classifier part); ``'activation'``
                records outputs of activation modules; any other value records
                the output of every ``nn.Sequential`` block.

        Returns:
            list: Recorded tensors in execution order.
        """
        if target == 'weights':
            return self._trace(x, targets=(nn.Conv2d, nn.Linear))

        # get activation function's output
        if target == 'activation':
            acti_cls = type(next(self.activation_generator))
            return self._trace(x, targets=(acti_cls, acti_cls))

        # get block output
        return self._trace(x)

    def _initialize_weights(self, xavier_init):
        """Initialise conv, 2-D batch-norm and linear parameters in place.

        Conv weights use ``N(0, sqrt(2 / (kh * kw * out_channels)))``, linear
        weights use ``N(0, 0.01)``, batch-norm weight/bias are set to 1/0 and
        all biases to 0.

        NOTE(review): ``xavier_init`` is accepted but not used -- confirm
        whether a Xavier scheme was meant to be selectable.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))

                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                if m.weight is not None:
                    m.weight.data.fill_(1)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                if m.bias is not None:
                    m.bias.data.zero_()

    @staticmethod
    def _decompose_batchnorm(module, x):
        """Apply a batch norm with its running statistics in two visible steps.

        Returns:
            tuple: ``(normalized, shifted)`` where ``normalized`` is
            ``(x - running_mean) / sqrt(running_var + eps)`` and ``shifted``
            additionally applies the affine weight and bias when present.
        """
        stats = module.state_dict()
        # Broadcast per-channel vectors over batch and trailing dims.
        shape = [1, -1] + [1] * (x.dim() - 2)

        running_mean = stats['running_mean'].view(*shape)
        running_var = stats['running_var'].view(*shape)

        # eps belongs inside the square root and comes from the module, so the
        # result matches what the layer computes in eval mode.
        normalized = (x - running_mean) / torch.sqrt(running_var + module.eps)

        shifted = normalized
        if 'weight' in stats:
            shifted = shifted * stats['weight'].view(*shape)
        if 'bias' in stats:
            shifted = shifted + stats['bias'].view(*shape)

        return normalized, shifted

    def get_decompose_activation(self, x):
        """Replay the forward pass, exposing batch-norm internals.

        Records the output of every conv / linear layer and activation module.
        For every batch-norm layer it records the normalised tensor and the
        scaled-and-shifted tensor.  Batch norm is always evaluated with its
        running statistics, whatever the module's training flag.

        Args:
            x: Network input.

        Returns:
            list: Recorded tensors in execution order.
        """
        acti_cls = type(next(self.activation_generator))
        outputs = []

        stages = (
            (self.feature, nn.Conv2d, nn.BatchNorm2d),
            (self.classifier, nn.Linear, nn.BatchNorm1d),
        )
        for stage_idx, (stage, weight_cls, norm_cls) in enumerate(stages):
            if stage_idx == 1:
                if self.avgpool is not None:
                    x = self.avgpool(x)
                x = x.view(x.size(0), -1)

            for module in stage.modules():
                if isinstance(module, (weight_cls, acti_cls)):
                    x = module(x)
                    outputs.append(x)
                elif isinstance(module, norm_cls):
                    normalized, x = self._decompose_batchnorm(module, x)
                    outputs.append(normalized)
                    outputs.append(x)
                elif not isinstance(module, nn.Sequential):
                    x = module(x)

        return outputs
